#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2016-2018, Niklas Hauser
# Copyright (c) 2017, Fabian Greif
# Copyright (c) 2020-2023, Christopher Durand
#
# This file is part of the modm project.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
# -----------------------------------------------------------------------------

import copy
from collections import defaultdict, OrderedDict

def port_ranges(gpios):
    """
    Summarise which pin range each GPIO port covers.

    :param gpios: iterable of dicts with a "port" entry and a "pin" entry
                  that is convertible with int().
    :returns: list of {"name", "start", "width"} dicts, one per port, sorted
              by the upper-cased port name. "start" is the lowest pin number
              and "width" the number of pins from lowest to highest inclusive.
    """
    pins_by_port = {}
    for entry in gpios:
        pins_by_port.setdefault(entry["port"], []).append(int(entry["pin"]))

    summary = []
    for port_name, pin_numbers in pins_by_port.items():
        # The search is seeded with 32 for the minimum and 0 for the maximum,
        # so those two values act as bounds on the reported range.
        lowest = min(32, *pin_numbers)
        highest = max(0, *pin_numbers)
        summary.append({
            "name": port_name.upper(),
            "start": lowest,
            "width": highest - lowest + 1,
        })
    return sorted(summary, key=lambda item: item["name"])

def translate(s):
    """Drop every underscore from `s`, then capitalize the result."""
    without_underscores = "".join(s.split("_"))
    return without_underscores.capitalize()

def get_driver(s):
    """
    Build the peripheral name of a signal or remap entry.

    :param s: mapping that may contain a "driver" and an "instance" entry.
    :returns: translate(driver) followed by translate(instance) when an
              instance is present, or None when there is no "driver" entry.
    """
    # Without a driver there is nothing to attach the instance to. Returning
    # None here avoids the `None += str` TypeError for entries that only carry
    # an "instance" key.
    if "driver" not in s:
        return None
    name = translate(s["driver"])
    if "instance" in s:
        name += translate(s["instance"])
    return name

def get_name(s):
    """Return the "name" entry of `s` after passing it through translate()."""
    raw_name = s["name"]
    return translate(raw_name)

def extract_signals(signals):
    """
    Index the signals that belong to a peripheral.

    :param signals: iterable of signal dicts; each one that yields a driver
                    must also provide "name" and "af" entries.
    :returns: dict keyed by driver name + signal name, whose values are
              {"driver", "name", "af"} dicts. Signals for which get_driver()
              returns None are left out; a later signal with the same key
              replaces an earlier one.
    """
    extracted = {}
    for signal in signals:
        peripheral = get_driver(signal)
        if peripheral is not None:
            signal_name = get_name(signal)
            extracted[peripheral + signal_name] = {
                "driver": peripheral,
                "name": signal_name,
                "af": signal["af"],
            }
    return extracted

def print_remap_group_table(signal_map, peripheral):
    """
    Render a text table of the signals of one peripheral per pin and group.

    :param signal_map: mapping of pin key (port letter followed by the pin
                       number) to a list of signal dicts with an "af" list.
    :param peripheral: driver name as produced by get_driver(); only signals
                       of this peripheral are listed.
    :returns: list of strings: the header row, a row of dashes, then one row
              per pin. Cells are centered and joined with "|".
    """
    names_by_pin = defaultdict(lambda: defaultdict(list))
    group_ids = set()
    for pin, pin_signals in signal_map.items():
        for pin_signal in pin_signals:
            if get_driver(pin_signal) != peripheral:
                continue
            for group_id in pin_signal["af"]:
                names_by_pin[pin][group_id].append(get_name(pin_signal))
            group_ids.update(pin_signal["af"])

    # First column holds the pin, the remaining columns one group id each.
    header = [peripheral] + sorted(group_ids)
    widths = [len(title) for title in header]
    table = [header]

    # Order rows by port letter first, then numerically by pin number.
    for pin in sorted(names_by_pin, key=lambda p: (p[0], int(p[1:]))):
        row = [pin.upper()] + [[] for _ in header[1:]]
        for group_id, names in names_by_pin[pin].items():
            row[header.index(group_id)].extend(names)
        row = [cell if not isinstance(cell, list) else ",".join(cell) for cell in row]
        widths = [max(width, len(cell)) for width, cell in zip(widths, row)]
        table.append(row)

    lines = []
    for row_number, row in enumerate(table):
        # One space of padding on each side of the widest entry of a column.
        lines.append("|".join(cell.center(width + 2) for cell, width in zip(row, widths)))
        if row_number == 0:
            lines.append("|".join("-" * (width + 2) for width in widths))

    return lines


def validate_alternate_functions(driver, env):
    """
    Check the GPIO driver data for ambiguous signal mappings.

    For the STM32F1/L1 series
    -------------------------
    These chips can only remap functions in groups, some of them with overlapping
    pins. A unique map must exist for (Signal.driver + Signal.instance + ordered_set(Pins)) -> Remap
    So the combination of Signal driver plus instance and an ordered set of Pins must not map to
    different Remap Groups, otherwise our HAL assumptions are broken.

    When "f1" is contained in driver["type"], every pin that is listed more
    than once within the groups of one remap entry is reported.

    For all other types
    -------------------
    The signals of each GPIO that carry an "af" entry are grouped by
    get_driver(signal); a peripheral with more than one distinct "af" value
    on the same pin is reported. The "af" values are collected in a set, so
    they must be hashable.

    :param driver: GPIO driver data ("type" plus "remap" or "gpio" entries).
    :param env: object providing `env.log.debug(message)`.
    :returns: True when nothing was reported, False otherwise. Findings are
              only logged at debug level; no exception is raised for them.
    """
    success = True

    if "f1" in driver["type"]:
        for remap in driver["remap"]:
            # pin key -> [(group id, signal name), ...]
            af_map = defaultdict(list)
            for group in remap["group"]:
                for signal in group["signal"]:
                    af_map[signal["port"] + signal["pin"]].append((group["id"], signal["name"]))
            for pin, afs in af_map.items():
                if len(afs) > 1:
                    # These values are only needed for the log message.
                    per = remap["driver"]
                    if "instance" in remap:
                        per += remap["instance"]
                    gafs = sorted({a[0] for a in afs})
                    nafs = sorted({a[1] for a in afs})
                    env.log.debug("P%s: Duplicate groups %s for signal '%s': %s" % (pin.upper(), gafs, per, nafs))
                    success = False
    else:
        for gpio in driver["gpio"]:
            if "signal" not in gpio:
                continue
            # peripheral name -> [(af, signal name), ...]
            af_map = defaultdict(list)
            for signal in gpio["signal"]:
                if "af" not in signal:
                    continue
                af_map[get_driver(signal)].append((signal["af"], signal["name"]))
            for per, afs in af_map.items():
                nafs = sorted({a[0] for a in afs})
                if len(nafs) > 1:
                    env.log.debug("P%s: Duplicate AFs %s for signal '%s': %s" % (gpio["port"].upper() + gpio["pin"], nafs, per, [a[1] for a in afs if a[0] in nafs]))
                    success = False
    if not success:
        env.log.debug("Gpio signal validation failed!")

    # Hand the result back so a caller can react to it instead of having to
    # read the debug log.
    return success

def get_remap_command(family, key):
    """
    Look up the register and bit mask names used to remap a pin.

    :param family: device family; 'c0', 'f0' and 'g0' have entries.
    :param key: pin key (port letter + pin number); 'a9' to 'a12' have entries.
    :returns: tuple (register, mask) of C identifier strings.
    :raises ValidateException: when there is no entry for (family, key).
    """
    # 'a9' shares its mask with 'a11' and 'a10' with 'a12' in this layout.
    separate_masks = {
        'a9': 'SYSCFG_CFGR1_PA11_RMP',
        'a10': 'SYSCFG_CFGR1_PA12_RMP',
        'a11': 'SYSCFG_CFGR1_PA11_RMP',
        'a12': 'SYSCFG_CFGR1_PA12_RMP',
    }
    # All four pins use one combined mask in this layout.
    combined_mask = dict.fromkeys(('a9', 'a10', 'a11', 'a12'), 'SYSCFG_CFGR1_PA11_PA12_RMP')
    masks_by_family = {
        'c0': separate_masks,
        'f0': combined_mask,
        'g0': separate_masks,
    }

    mask = masks_by_family.get(family, {}).get(key)
    if mask is None:
        raise ValidateException("Unknown Remap for GPIO: 'P{}'".format(key.upper()))
    return ('SYSCFG->CFGR1', mask)

# Build properties shared by the hooks in this file: prepare() and validate()
# fill this dict, build() merges it into the template substitutions.
bprops = {}

# -----------------------------------------------------------------------------
def init(module):
    """Set the module name and point its description at `module.md`."""
    module.name = ":platform:gpio"
    module.description = FileReader("module.md")

def prepare(module, options):
    """
    Declare the option, dependencies and query of the GPIO module.

    Stores the port ranges and a port name -> index mapping in `bprops`,
    adds the "enable_ports" set option (all ports enabled by default) and
    registers the "all_signals" query.

    :returns: False when the target has no "gpio:stm32*" driver, else True.
    """
    target = options[":target"]
    if not target.has_driver("gpio:stm32*"):
        return False

    ranges = port_ranges(target.get_driver("gpio")["gpio"])
    bprops["ranges"] = ranges
    # Port names in range order, each mapped to its position in that order.
    bprops["ports"] = OrderedDict((port["name"], index) for index, port in enumerate(ranges))
    port_names = list(bprops["ports"])

    enable_ports = EnumerationOption(
        name="enable_ports",
        description="Enable clock for these GPIO ports during startup",
        enumeration=port_names)
    module.add_set_option(enable_ports, default=list(port_names))

    module.depends(
        ":architecture:gpio",
        ":cmsis:device",
        ":math:utils",
        ":platform:rcc")

    # "all_signals" is looked up only when the query runs; validate() is the
    # function that stores it.
    all_signals_query = EnvironmentQuery(name="all_signals", factory=lambda env: bprops["all_signals"])
    module.add_query(all_signals_query)

    return True

def validate(env):
    """
    Collect the per-pin signal data for the templates and check the ADC count.

    Fills the module-level `bprops` dict with:
      - "group_map" and "remaps" (only when "f1" is in the driver type),
      - one entry per GPIO keyed by port + pin, holding "gpio", "signals",
        "has_remap" and, for remapped pins, "remap_reg", "remap_mask" and
        "remap_value",
      - "all_signals", "gpios" and "port_width".

    The GPIO entries of the driver data are modified in place: every signal's
    "af" becomes a list and each gpio gets a "signal" list.

    :raises ValidateException: when the highest ADC instance exceeds the limit
        checked at the end of this function, or (via get_remap_command) when a
        remapped pin has no known remap command.
    """
    device = env[":target"]
    driver = device.get_driver("gpio")
    # Not an issue anymore, but kept here for future issues
    # validate_alternate_functions(driver, env)

    # pin key (port + pin) -> signals that come from remap groups
    signal_map = defaultdict(list)
    if "f1" in driver["type"]:
        # map all remaps onto pins
        for remap in driver["remap"]:
            # smap = defaultdict(list)
            for group in remap["group"]:
                for signal in group["signal"]:
                    key = signal["port"] + signal["pin"]
                    # A signal already recorded for this pin with the same
                    # driver and name only gets the additional group id.
                    for sig in signal_map[key]:
                        if (get_driver(sig) == get_driver(remap) and
                            get_name(sig) == get_name(signal)):
                            sig["af"].append(group["id"])
                            break
                    else:
                        # First occurrence: copy the signal and attach the
                        # remap's driver/instance and the group id as its "af".
                        sig = copy.deepcopy(signal)
                        sig["driver"] = remap["driver"]
                        sig["af"] = [group["id"]]
                        if "instance" in remap:
                            sig["instance"] = remap["instance"]
                        signal_map[key].append(sig)
        bprops["group_map"] = signal_map
        bprops["remaps"] = driver["remap"]

    # Compute the set of remapped pins
    # Only the pins of the first package entry are inspected.
    remapped_gpios = {}
    for p in driver["package"][0]["pin"]:
        variant = p.get("variant", "")
        if "remap" in variant:  # also matches "remap-default"
            # Characters 1..3 of the pin name, lower-cased; a third character
            # that is not a digit is dropped.
            # presumably this turns e.g. "PA9" into "a9" -- confirm against the device data
            name = p["name"][1:4].strip().lower()
            if len(name) > 2 and not name[2].isdigit():
                name = name[:2]
            remapped_gpios[name] = (variant == "remap") # "remap-default" -> False

    all_signals = {}
    for gpio in driver["gpio"]:
        key = gpio["port"] + gpio["pin"]
        raw_signals = gpio["signal"] if "signal" in gpio else []
        # Normalise "af" to a list so it has the same shape as the remap
        # group signals built above.
        for signal in raw_signals:
            if "af" in signal:
                signal["af"] = [signal["af"]]
            else:
                signal["af"] = []
        if key in signal_map:
            raw_signals.extend(signal_map[key])
        gpio["signal"] = raw_signals
        extracted_signals = extract_signals(raw_signals)
        all_signals.update(extracted_signals)
        signal_names = sorted(list(set(s["name"] for s in extracted_signals.values())))
        # Group the extracted signals by signal name, names in sorted order.
        # NOTE(review): each group is sorted by "name", which is identical for
        # all of its members, so the sort leaves the order unchanged.
        extracted_signals = OrderedDict([(name, sorted([s for s in extracted_signals.values() if s["name"] == name], key=lambda si:si["name"])) for name in signal_names])
        has_remap = key in remapped_gpios
        bprops[key] = {
            "gpio": gpio,
            "signals": extracted_signals,
            "has_remap": has_remap,
        }
        if has_remap:
            reg, mask = get_remap_command(device.identifier.family, key)
            bprops[key].update({
                "remap_reg": reg,
                "remap_mask": mask,
                "remap_value": remapped_gpios[key]
            })

    bprops["all_signals"] = sorted(list(set(s["name"] for s in all_signals.values())))
    bprops["gpios"] = [bprops[g["port"] + g["pin"]] for g in driver["gpio"]]
    bprops["port_width"] = 16

    # Check the max number of ADC instances
    # A driver without an "instance" entry counts as a single instance 1.
    max_adc_instance = max(map(int, device.get_driver("adc").get("instance", [1])))
    # Limit is 3 for family "f1", 5 for "g4" and 4 for every other family.
    if (max_adc_instance > (3 if device.identifier["family"] == "f1" else 5 if device.identifier["family"] == "g4" else 4)):
        raise ValidateException("Too many ADC instances: '{}'".format(max_adc_instance))

def build(env):
    """
    Generate the GPIO sources below "modm/src/modm/platform/gpio".

    The target identifier and everything stored in `bprops` are made
    available to the templates as substitutions.
    """
    env.substitutions["target"] = env[":target"].identifier
    env.substitutions.update(bprops)
    env.outbasepath = "modm/src/modm/platform/gpio"

    def digits_only(name):
        # Keep just the digit characters of the given name.
        return "".join(filter(str.isdigit, name))

    env.template("data.hpp.in", filters={"to_adc_channel": digits_only})
    env.template("static.hpp.in")
    env.template("../common/pins.hpp.in", "pins.hpp")

    for local_template in (
        "port.hpp.in",
        "software_port.hpp.in",
        "set.hpp.in",
        "base.hpp.in",
    ):
        env.template(local_template)

    for source, destination in (
        ("../common/unused.hpp.in", "unused.hpp"),
        ("../common/port_shim.hpp.in", "port_shim.hpp"),
    ):
        env.template(source, destination)

    # The clock enable code is only generated when a port is selected.
    if len(env["enable_ports"]) > 0:
        env.template("enable.cpp.in")

    env.copy("../common/inverted.hpp", "inverted.hpp")
    connector_filters = {
        "formatPeripheral": get_driver,
        "printSignalMap": print_remap_group_table,
    }
    env.template("../common/connector.hpp.in", "connector.hpp", filters=connector_filters)
